package org.framework.lazy.cloud.network.heartbeat.common;

import io.netty.channel.Channel;
import lombok.extern.slf4j.Slf4j;
import org.framework.lazy.cloud.network.heartbeat.common.utils.ChannelAttributeKeyUtils;
import org.wu.framework.core.utils.ObjectUtils;

import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.stream.Collectors;

/**
 * 通道上下文
 */
@Slf4j
public class ChannelContext {

    private final static
    ConcurrentHashMap<String/*clientId*/, List<Channel>/*通道*/>
            cacheClientChannelConcurrentHashMap = new ConcurrentHashMap<>();

    /**
     * 新增通道
     *
     * @param channel  通道
     * @param clientId 客户端ID
     */
    public static void push(Channel channel, String clientId) {
        // 如果服务端已经存在 移除
        if (cacheClientChannelConcurrentHashMap.containsKey(clientId)) {
            List<Channel> existChannelList = new ArrayList<>();
            List<Channel> oldChannels = cacheClientChannelConcurrentHashMap.get(clientId);
            for (Channel existChannel : oldChannels) {
                if (existChannel != null) {
                    if (existChannel.isActive()) {
                        existChannelList.add(existChannel);
                    } else {
                        log.warn("close channel with  client:{}", ChannelAttributeKeyUtils.getClientId(existChannel));
                        existChannel.close();
                    }
                }
            }
            existChannelList.add(channel);
            cacheClientChannelConcurrentHashMap.put(clientId, existChannelList);
        } else {
            cacheClientChannelConcurrentHashMap.putIfAbsent(clientId, Collections.synchronizedList(new ArrayList<>(List.of(channel))));
        }

    }

    /**
     * 新增通道
     *
     * @param channel  通道
     * @param clientId 客户端ID
     */
    public static void push(Channel channel,  byte[] clientId) {
        push(channel, new String(clientId, StandardCharsets.UTF_8));
    }

    /**
     * 获取指定服务端所有通道
     *
     * @return 返回所有通道信息
     */
    public static ConcurrentMap<String/*clientId*/, List<Channel>/*通道*/> getChannels() {
        return cacheClientChannelConcurrentHashMap;
    }

    /**
     * 获取所有通道
     *
     * @return 返回所有通道信息
     */
    public static List<String> getClientIds() {
        return new ArrayList<>(cacheClientChannelConcurrentHashMap.keySet().stream().toList());
    }


    /**
     * 根据通道ID获取通道信息
     *
     * @param clientId 客户端ID
     * @return 通道信息
     */
    public static List<Channel> get(byte[] clientId) {
        try {
            return cacheClientChannelConcurrentHashMap.get(new String(clientId, StandardCharsets.UTF_8));
        } catch (Exception e) {
            e.printStackTrace();
            // 无法通过客户端ID[{}]获取通道信息
            log.error("Unable to obtain channel information through  client ID [{}]", new String(clientId));
            return null;
        }

    }

    /**
     * 根据通道ID获取通道信息
     *
     * @param clientId 客户端ID
     * @return 通道信息
     */
    public static List<Channel> get(String clientId) {
        return get(clientId.getBytes(StandardCharsets.UTF_8));
    }

    /**
     * 根据通道ID获取通道信息
     *
     * @param clientId 客户端ID
     * @return 通道信息
     */
    public static Channel getLoadBalance(byte[] clientId) {
        List<Channel> channels = get(clientId);
        if (ObjectUtils.isEmpty(channels)) {
            return null;
        }
        channels = channels.stream().filter(Channel::isActive).collect(Collectors.toList());
        if (ObjectUtils.isEmpty(channels)) {
            return null;
        }
        // TODO  负载问题
        return channels.get(0);
    }

    /**
     * 根据通道ID获取通道信息
     *
     * @param clientId 客户端ID
     * @return 通道信息
     */
    public static Channel getLoadBalance(String clientId) {
        return getLoadBalance(clientId.getBytes(StandardCharsets.UTF_8));
    }

    /**
     * 关闭通道
     *
     * @param clientId 客户端ID
     */
    public static void clear(String clientId) {
        List<Channel> channels = get(clientId);
        if (channels != null) {
            remove(clientId);
            for (Channel channel : channels) {
                if (channel != null && channel.isActive()) {
                    channel.close();
                }
            }
        } else {
            // log warm
            // 无法通过客户ID:[{}]移除客户端
            log.warn("Unable to remove client through clientId: [{}]", clientId);
        }
    }

    /**
     * 通过客户端ID移除客户端通道
     *
     * @param clientId 客户端ID
     */
    public static void remove(byte[] clientId) {
        List<Channel> clientChannel = get(clientId);
        if (clientChannel != null) {
            cacheClientChannelConcurrentHashMap.remove(new String(clientId, StandardCharsets.UTF_8));
        } else {
            // log warm 无法通过客户ID:[{}]移除客户端
            log.warn("Unable to remove client through clientId: [{}]", new String(clientId));
        }
    }

    /**
     * 通过客户端ID移除客户端通道
     *
     * @param clientId 客户端ID
     */
    public static void remove(String clientId) {
        List<Channel> clientChannel = get(clientId);
        if (clientChannel != null) {
            cacheClientChannelConcurrentHashMap.remove(clientId);
        } else {
            // log warm 无法通过客户ID:[{}]移除客户端
            log.warn("Unable to remove client through  clientId: 【{}】", clientId);
        }
    }


    /**
     * 客户端通道信息
     */
    public interface ClientChannel {

        /**
         * 客户端ID
         */
        byte[] getClientId();


        /**
         * 通道
         */
        Channel getChannel();

    }

}

